{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8evZeVpdJHR_"
      },
      "source": [
        "Licensed under the Apache License, Version 2.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8evZeVpdJHR_"
      },
      "source": [
        "# yobo x NGP, interactive training/rendering with Multiscope (v3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Rgx2ZX6_I5n8"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import time\n",
        "\n",
        "import flax\n",
        "from flax.training import checkpoints\n",
        "import gin\n",
        "import jax\n",
        "import optax\n",
        "from jax import random\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import functools\n",
        "\n",
        "import mediapy as media\n",
        "# `multiscope` is used below and in later cells, so it must be imported here;\n",
        "# without this line a fresh kernel fails with NameError.\n",
        "import multiscope\n",
        "from six.moves import reload_module\n",
        "from colabtools import adhoc_import, frontend\n",
        "from colabtools.interactive_widgets import ProgressIter\n",
        "\n",
        "# Start the Multiscope server; `port` is reused later to build the dashboard\n",
        "# URL.\n",
        "port = multiscope.start_server()\n",
        "# No renderer exists yet. The renderer cell checks this for None before\n",
        "# reusing `renderer.controller.spl` from a previous renderer.\n",
        "renderer = None"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PC2faq4UJ6ni"
      },
      "outputs": [],
      "source": [
        "# Thanks to using adhoc_import, you can edit these files in Cider, then use\n",
        "# reload_module to update them w/o relaunching or even restarting the runtime.\n",
        "\n",
        "# Free every buffer the JAX backend still reports as live (e.g. from a\n",
        "# previous run of the notebook).\n",
        "backend = jax.lib.xla_bridge.get_backend()\n",
        "for buf in backend.live_buffers():\n",
        "  buf.delete()\n",
        "\n",
        "# Drop any gin bindings left over from a previous run. The bare\n",
        "# `gin.unlock_config()` call that used to follow was removed: it is a context\n",
        "# manager, so calling it without `with` never executed anything.\n",
        "gin.clear_config()\n",
        "\n",
        "\n",
        "# Import each project module and immediately reload it so that source edits\n",
        "# are picked up. `configs`, `grid_utils` and `vis` are imported without a\n",
        "# reload. The statement order below is kept exactly as it was.\n",
        "from google_research.yobo.internal import configs\n",
        "from google_research.yobo.internal import grid_utils\n",
        "from google_research.yobo.internal import camera_utils\n",
        "camera_utils = reload_module(camera_utils)\n",
        "from google_research.yobo.internal import datasets\n",
        "datasets = reload_module(datasets)\n",
        "from google_research.yobo.internal import math\n",
        "math = reload_module(math)\n",
        "from google_research.yobo.internal import render\n",
        "render = reload_module(render)\n",
        "from google_research.yobo.internal import coord\n",
        "coord = reload_module(coord)\n",
        "from google_research.yobo.internal import sample_net_utils\n",
        "sample_net_utils = reload_module(sample_net_utils)\n",
        "from google_research.yobo.internal.inverse_render import render_utils\n",
        "render_utils = reload_module(render_utils)\n",
        "from google_research.yobo.internal import models\n",
        "models = reload_module(models)\n",
        "from google_research.yobo.internal import sampling\n",
        "sampling = reload_module(sampling)\n",
        "from google_research.yobo.internal import geometry\n",
        "geometry = reload_module(geometry)\n",
        "from google_research.yobo.internal import integration\n",
        "integration = reload_module(integration)\n",
        "from google_research.yobo.internal import shading\n",
        "shading = reload_module(shading)\n",
        "from google_research.yobo.internal import material\n",
        "material = reload_module(material)\n",
        "from google_research.yobo.internal import stepfun\n",
        "stepfun = reload_module(stepfun)\n",
        "from google_research.yobo.internal import train_utils\n",
        "train_utils = reload_module(train_utils)\n",
        "from google_research.yobo.internal import loss_utils\n",
        "loss_utils = reload_module(loss_utils)\n",
        "from google_research.yobo.internal import utils\n",
        "utils = reload_module(utils)\n",
        "from google_research.yobo.internal import vis\n",
        "\n",
        "from google_research.yobo import multiscope_renderer\n",
        "multiscope_renderer = reload_module(multiscope_renderer)\n",
        "\n",
        "\n",
        "depot_base = ''\n",
        "config_base = depot_base + 'third_party/google_research/google_research/yobo/configs/'\n",
        "\n",
        "# Register each gin search path only if it is not already present, so\n",
        "# re-running this cell does not add duplicates.\n",
        "for d in [depot_base, config_base]:\n",
        "  if d not in gin.config._LOCATION_PREFIXES:\n",
        "    gin.add_config_file_search_path(d)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JTXcU_ZiPlb2"
      },
      "source": [
        "# Load configs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DKSqy7mtORw1"
      },
      "source": [
        "## Dataset config\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qFiFZ1iJOOVm"
      },
      "outputs": [],
      "source": [
        "# Dataset\n",
        "\n",
        "\n",
        "# Cache Checkpoint\n",
        "ckpt_dir = None\n",
        "\n",
        "# Cornelly\n",
        "\n",
        "# Lego small light\n",
        "\n",
        "# Scraperbikes\n",
        "\n",
        "# Configs\n",
        "config = None\n",
        "\n",
        "# Other settings\n",
        "use_material = True\n",
        "use_light_sampler = False\n",
        "\n",
        "# True when a cache checkpoint directory was given above; several settings\n",
        "# below switch on this.\n",
        "has_cache_ckpt = ckpt_dir is not None\n",
        "\n",
        "# The trailing literals are manual override knobs.\n",
        "optimize_cache = (not use_material) or False\n",
        "resample_material = use_material and True\n",
        "render_variate = True\n",
        "\n",
        "jitter_rays = 0 if use_material else 0\n",
        "anneal_slope = 0.0 if has_cache_ckpt else 10.0\n",
        "\n",
        "num_secondary_samples = 2 if not resample_material else 8\n",
        "\n",
        "# scale_fac divides the batch size and learning rates and multiplies the\n",
        "# step counts below.\n",
        "if has_cache_ckpt or (use_light_sampler and not use_material):\n",
        "  scale_fac = 4\n",
        "else:\n",
        "  scale_fac = 1\n",
        "\n",
        "batch_size = 65536 // scale_fac\n",
        "grad_accum_steps = 1\n",
        "max_steps = 25000 * scale_fac\n",
        "\n",
        "# Base learning-rate schedule.\n",
        "lr_init = 0.01 / scale_fac\n",
        "lr_final = 0.001 / scale_fac\n",
        "lr_delay_steps = 2500 * scale_fac\n",
        "\n",
        "# Schedule for the 'Cache' parameter group.\n",
        "lr_init_cache = (0.0005 if has_cache_ckpt else 0.01) / scale_fac\n",
        "lr_final_cache = (0.00005 if has_cache_ckpt else 0.001) / scale_fac\n",
        "lr_delay_steps_cache = (0 if has_cache_ckpt else 2500) * scale_fac\n",
        "\n",
        "# Schedule for the 'MaterialShader' parameter group.\n",
        "lr_init_material = (0.0005 if has_cache_ckpt else 0.005) / scale_fac\n",
        "lr_final_material = (0.00005 if has_cache_ckpt else 0.0005) / scale_fac\n",
        "lr_delay_steps_material = (0 if has_cache_ckpt else 2500) * scale_fac\n",
        "\n",
        "# Schedule for the 'LightSampler' parameter group.\n",
        "lr_init_light = (0.0005 if use_material else 0.001) / scale_fac\n",
        "lr_final_light = (0.00005 if use_material else 0.0001) / scale_fac\n",
        "lr_delay_steps_light = (0 if use_material else 0) * scale_fac\n",
        "\n",
        "# Per-group overrides, later formatted into the `Config.extra_opt_params`\n",
        "# gin binding. Multiplying by the boolean `optimize_cache` zeroes the cache\n",
        "# learning rates when the cache is not being optimized.\n",
        "extra_opt_params = {\n",
        "    'Cache': dict(\n",
        "        lr_delay_steps=lr_delay_steps_cache,\n",
        "        lr_final=lr_final_cache * optimize_cache,\n",
        "        lr_init=lr_init_cache * optimize_cache,\n",
        "    ),\n",
        "    'MaterialShader': dict(\n",
        "        lr_delay_steps=lr_delay_steps_material,\n",
        "        lr_final=lr_final_material,\n",
        "        lr_init=lr_init_material,\n",
        "    ),\n",
        "    'LightSampler': dict(\n",
        "        lr_delay_steps=lr_delay_steps_light,\n",
        "        lr_final=lr_final_light,\n",
        "        lr_init=lr_init_light,\n",
        "    ),\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gNYWYsN_ejb_"
      },
      "source": [
        "## Model config"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LaMDJOEueCKQ"
      },
      "outputs": [],
      "source": [
        "# Config\n",
        "# Select exactly one scene config. `config_files` must be defined or the list\n",
        "# comprehension below raises NameError; the first option is the default and\n",
        "# the alternatives stay commented out.\n",
        "config_files = ['blender_ngp_yobo_material_cornelly.gin']\n",
        "# config_files = ['blender_ngp_yobo_material_lego.gin']\n",
        "# config_files = ['real_ngp_yobo_material_scraperbikes.gin']\n",
        "\n",
        "\n",
        "gin_configs = [config_base + f for f in config_files]\n",
        "\n",
        "# Bindings that push the notebook-level settings into gin.\n",
        "# NOTE(review): with ckpt_dir = None the first binding passes the string\n",
        "# 'None' rather than None - confirm downstream code expects that.\n",
        "gin_bindings = [\n",
        "  f'Config.ckpt_dir = \"{ckpt_dir}\"',\n",
        "  f'Config.max_steps = {max_steps * grad_accum_steps}',\n",
        "  f'Config.batch_size = {batch_size}',\n",
        "  f'Config.grad_accum_steps = {grad_accum_steps}',\n",
        "  f'Config.lr_init = {lr_init}',\n",
        "  f'Config.lr_final = {lr_final}',\n",
        "  f'Config.lr_delay_steps = {lr_delay_steps}',\n",
        "  f'Config.extra_opt_params = {extra_opt_params}',\n",
        "  f'ProposalVolumeSampler.anneal_slope = {anneal_slope}',\n",
        "  f'MaterialModel.use_material = {use_material}',\n",
        "  f'MaterialModel.use_light_sampler = {use_light_sampler}',\n",
        "  f'MaterialModel.resample_material = {resample_material}',\n",
        "  f'MaterialModel.render_variate = {render_variate}',\n",
        "  f'MaterialMLP.num_secondary_samples = {num_secondary_samples}',\n",
        "  f'MaterialMLP.render_num_secondary_samples = {num_secondary_samples}',\n",
        "]\n",
        "\n",
        "# Parse from a clean slate, build the Config object and print the resulting\n",
        "# gin configuration for inspection.\n",
        "gin.clear_config()\n",
        "gin.parse_config_files_and_bindings(gin_configs, gin_bindings, skip_unknown=True)\n",
        "config = configs.Config()\n",
        "print(gin.config_str())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kDHvjufYeLgn"
      },
      "source": [
        "# Load dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P5Upmyi_eGn4"
      },
      "outputs": [],
      "source": [
        "# Load the 'train' split from the data directory named in the parsed config.\n",
        "dataset = datasets.load_dataset('train', config.data_dir, config)\n",
        "\n",
        "# Plot `dataset.camtoworlds` with the renderer module's pose plotter.\n",
        "train_poses = dataset.camtoworlds\n",
        "multiscope_renderer.plot_poses(train_poses, eps=0.05)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "anFUaJjxD9XQ"
      },
      "outputs": [],
      "source": [
        "# Print the largest value found in the loaded images.\n",
        "max_image_value = jnp.max(dataset.images)\n",
        "print(max_image_value)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C3gQLR9x8ZkQ"
      },
      "source": [
        "# Load model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GLQB65CCMA-L"
      },
      "outputs": [],
      "source": [
        "# Load config again\n",
        "gin.clear_config()\n",
        "gin.parse_config_files_and_bindings(gin_configs, gin_bindings, skip_unknown=True)\n",
        "config = configs.Config()\n",
        "\n",
        "# Draw the model seed explicitly and print it, so a run can be reproduced by\n",
        "# hard-coding the printed value here.\n",
        "model_seed = np.random.randint(1000)\n",
        "print(f'Model PRNG seed: {model_seed}')\n",
        "\n",
        "# Create model and training functions.\n",
        "# dataset.reload_mesh(config)\n",
        "model, train_state, render_eval_pfn, train_pstep, _ = train_utils.setup_model(\n",
        "    config, random.PRNGKey(model_seed), dataset\n",
        ")\n",
        "\n",
        "# Restore cache checkpoint. 'LightSampler' is added to the restored prefixes\n",
        "# only when both the material model and the light sampler are enabled.\n",
        "train_state = train_utils.restore_partial_checkpoint(\n",
        "    config, train_state,\n",
        "    prefixes=(\n",
        "        ['Cache'] + (['LightSampler'] if use_material and use_light_sampler else [])\n",
        "    ),\n",
        "    replace_dict={\n",
        "        'Cache': 'Cache',\n",
        "        'LightSampler': 'LightSampler',\n",
        "    }\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dCPyJRLk9W0F"
      },
      "outputs": [],
      "source": [
        "# (Optionally) restore the full training state from `config.ckpt_dir`,\n",
        "# replacing the partially restored state.\n",
        "train_state = checkpoints.restore_checkpoint(\n",
        "    ckpt_dir=config.ckpt_dir,\n",
        "    target=train_state,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5X5Or1yRPyle"
      },
      "source": [
        "# Training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zbeT40CzPfZh"
      },
      "source": [
        "## Model Training Loop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DB7OWk-LMJ7H"
      },
      "outputs": [],
      "source": [
        "## For the model\n",
        "\n",
        "# Create the renderer.\n",
        "multiscope.reset()\n",
        "# Reuse the `spl` object of a renderer left over from an earlier run of this\n",
        "# cell; otherwise start from a fresh Spliner.\n",
        "spl = renderer.controller.spl if renderer is not None else multiscope_renderer.Spliner()\n",
        "\n",
        "# Downscale by 8 for 'llff' dataset loaders and by 2 otherwise, then round\n",
        "# the width down to a multiple of 16.\n",
        "scale_factor = 8 if 'llff' in config.dataset_loader else 2\n",
        "width = 16 * ((dataset.width // scale_factor) // 16)\n",
        "\n",
        "# The same `width` is passed for the first two entries; the third entry is\n",
        "# derived from the width ratio and the first pixtocams element.\n",
        "hwf_init = (\n",
        "    width,\n",
        "    width,\n",
        "    (float(width) / dataset.width) / dataset.pixtocams[0, 0, 0],\n",
        ")\n",
        "\n",
        "renderer = multiscope_renderer.MultiscopeRenderer(\n",
        "    dataset,\n",
        "    config,\n",
        "    model,\n",
        "    train_state,\n",
        "    train_pstep,\n",
        "    spl,\n",
        "    hwf_init=hwf_init,\n",
        ")\n",
        "\n",
        "# Training is enabled immediately; comment this line out if you do not want\n",
        "# to start training right away.\n",
        "renderer.training = True\n",
        "\n",
        "# Run one step to jit the render function.\n",
        "renderer.step()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8Cw6GrcpoOOm"
      },
      "outputs": [],
      "source": [
        "# Set to False to skip this cell; the loop below never returns on its own and\n",
        "# has to be interrupted manually.\n",
        "run_interactive_loop = True\n",
        "\n",
        "if run_interactive_loop:\n",
        "  # Open the Multiscope dashboard page for the server started earlier.\n",
        "  dashboard_url = multiscope.get_dashboard_url(port)\n",
        "  frontend.OpenUrl(dashboard_url)\n",
        "\n",
        "  # Set training to true\n",
        "  renderer.training = True\n",
        "\n",
        "  # Step the renderer indefinitely.\n",
        "  while True:\n",
        "    renderer.step()\n",
        "    #time.sleep(.01)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9wznQ_F2c7oy"
      },
      "source": [
        "# Checkpoint"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k0YOvR29zDtk"
      },
      "outputs": [],
      "source": [
        "# Save checkpoint\n",
        "from datetime import date\n",
        "today = date.today()\n",
        "# Strip trailing slashes first so a data_dir such as 'a/b/scene/' still\n",
        "# yields 'scene' instead of an empty scene name.\n",
        "scene_name = config.data_dir.rstrip('/').split('/')[-1]\n",
        "\n",
        "# Derive a size label from the name of the first config file.\n",
        "if 'tiny' in config_files[0]:\n",
        "  model_suffix = 'tiny'\n",
        "elif 'small' in config_files[0]:\n",
        "  model_suffix = 'small'\n",
        "else:\n",
        "  model_suffix = 'large'\n",
        "\n",
        "# Output layout: <scene>/<date>/<size>[/light_sampler][/material].\n",
        "# Note that this rebinds the notebook-level `ckpt_dir` variable.\n",
        "ckpt_dir = f'{scene_name}/{today}/{model_suffix}' + ('/light_sampler' if use_light_sampler else '') + ('/material' if use_material else '')\n",
        "\n",
        "# Take the single-device copy of the renderer state and pass the step as a\n",
        "# plain Python int.\n",
        "train_state = flax.jax_utils.unreplicate(renderer.state)\n",
        "train_step = int(train_state.step)\n",
        "checkpoints.save_checkpoint(ckpt_dir=ckpt_dir, target=train_state, step=train_step, overwrite=True)"
      ]
    }
  ],
  "nbformat": 4,
  "nbformat_minor": 0
}
